﻿using System;
using System.Collections.Generic;
using System.Data;
using System.Data.SqlClient;
using System.Linq;
using System.Text;

namespace ZhCun.SqlServerRestore
{
    /// <summary>
    /// Restores a SQL Server database from a backup (.bak) file into a new database name:
    /// tests the connection, reads the backup's file list, creates the target database if
    /// needed and restores into it, moving every file under a chosen directory.
    /// </summary>
    class RestoreTools
    {
        /// <summary>
        /// Connection string produced by <see cref="SetConnectConfig"/>; always targets master.
        /// </summary>
        string ConnString { set; get; }

        /// <summary>
        /// Command timeout, in seconds, used for the RESTORE batch issued by <see cref="RestoreDB"/>.
        /// 0 means "no timeout" (SqlCommand semantics). Restoring a real database routinely takes
        /// longer than the 60 s used for other commands, and a client-side timeout aborts it midway.
        /// </summary>
        public int RestoreTimeoutSeconds { set; get; } = 0;

        /// <summary>
        /// Configures the SQL authentication login used by all subsequent operations.
        /// </summary>
        /// <param name="server">Server name / address (Data Source).</param>
        /// <param name="uid">SQL login name.</param>
        /// <param name="pwd">SQL login password.</param>
        public void SetConnectConfig(string server, string uid, string pwd)
        {
            // The builder escapes each value, so a password containing ';' or '=' can neither
            // break the connection string nor inject additional keywords.
            SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder();
            builder.DataSource = server ?? string.Empty;
            builder.InitialCatalog = "master";
            builder.UserID = uid ?? string.Empty;
            builder.Password = pwd ?? string.Empty;
            ConnString = builder.ConnectionString;
        }

        /// <summary>
        /// Result of the most recent <see cref="TestConnect"/> call.
        /// </summary>
        public bool ConnectSuccess { set; get; }

        /// <summary>
        /// Opens (and closes) a connection with the current configuration and records the
        /// outcome in <see cref="ConnectSuccess"/>.
        /// </summary>
        /// <returns>true when the connection could be opened.</returns>
        /// <exception cref="Exception">Any connection failure is rethrown after
        /// <see cref="ConnectSuccess"/> is set to false.</exception>
        public bool TestConnect()
        {
            try
            {
                using (SqlConnection conn = new SqlConnection(ConnString))
                {
                    conn.Open();
                    ConnectSuccess = true;
                    return true;
                }
            }
            catch (Exception)
            {
                ConnectSuccess = false;
                // "throw;" preserves the original stack trace ("throw ex;" reset it).
                throw;
            }
        }

        /// <summary>
        /// A logical file of a backup, as returned by RESTORE FILELISTONLY.
        /// </summary>
        class OriginalFile
        {
            /// <summary>
            /// Logical file name.
            /// </summary>
            public string LogicalName { set; get; }

            /// <summary>
            /// Physical path recorded in the backup.
            /// </summary>
            public string PhysicalName { set; get; }

            /// <summary>
            /// File type: "D" = data file, "L" = log file.
            /// </summary>
            public string Type { set; get; }

            /// <summary>
            /// Name suffix used for the restored file: "Data.mdf" for data, "Log.ldf" for log,
            /// empty for any other type.
            /// </summary>
            public string TypeExpandName
            {
                get
                {
                    // Invariant upper-casing; tolerate a missing Type instead of throwing.
                    switch ((Type ?? string.Empty).ToUpperInvariant())
                    {
                        case "D":
                            return "Data.mdf";
                        case "L":
                            return "Log.ldf";
                        default:
                            return string.Empty;
                    }
                }
            }
        }

        #region ado

        /// <summary>
        /// Executes a non-query statement with the default 60 s timeout.
        /// </summary>
        int ExecNonQuery(string sqlText, params SqlParameter[] param)
        {
            return ExecNonQueryCore(sqlText, 60, param);
        }

        /// <summary>
        /// Executes a non-query statement with an explicit command timeout (0 = no timeout).
        /// </summary>
        int ExecNonQueryCore(string sqlText, int timeoutSeconds, SqlParameter[] param)
        {
            using (SqlConnection conn = new SqlConnection(ConnString))
            {
                using (SqlCommand cmd = new SqlCommand(sqlText, conn))
                {
                    if (param != null)
                        cmd.Parameters.AddRange(param);
                    cmd.CommandTimeout = timeoutSeconds;
                    conn.Open();
                    return cmd.ExecuteNonQuery();
                }
            }
        }

        /// <summary>
        /// Runs a query and returns every result set it produces.
        /// </summary>
        DataSet GetDataSet(string sqlText, params SqlParameter[] param)
        {
            using (SqlConnection conn = new SqlConnection(ConnString))
            {
                using (SqlCommand cmd = new SqlCommand(sqlText, conn))
                {
                    if (param != null)
                        cmd.Parameters.AddRange(param);
                    DataSet ds = new DataSet();
                    using (SqlDataAdapter sda = new SqlDataAdapter(cmd))
                    {
                        sda.Fill(ds);
                        return ds;
                    }
                }
            }
        }

        /// <summary>
        /// Runs a query and returns its first result set.
        /// </summary>
        DataTable ExecDataTable(string sqlText, params SqlParameter[] param)
        {
            return GetDataSet(sqlText, param).Tables[0];
        }

        /// <summary>
        /// Runs a query and returns the first column of the first row.
        /// </summary>
        object GetScalar(string sqlText, params SqlParameter[] param)
        {
            using (SqlConnection conn = new SqlConnection(ConnString))
            {
                using (SqlCommand cmd = new SqlCommand(sqlText, conn))
                {
                    if (param != null)
                        cmd.Parameters.AddRange(param);
                    conn.Open();
                    return cmd.ExecuteScalar();
                }
            }
        }

        #endregion

        #region sql text helpers

        /// <summary>
        /// Quotes a T-SQL identifier as [name], doubling any ']' (same rule as QUOTENAME).
        /// </summary>
        static string QuoteName(string name)
        {
            return "[" + (name ?? string.Empty).Replace("]", "]]") + "]";
        }

        /// <summary>
        /// Renders a T-SQL Unicode string literal N'...', doubling any embedded quote.
        /// </summary>
        static string SqlLiteral(string value)
        {
            return "N'" + (value ?? string.Empty).Replace("'", "''") + "'";
        }

        /// <summary>
        /// Joins a server-side directory and a file name, adding a separator only when the
        /// directory does not already end with one. '/' is used only for directories that
        /// are written exclusively with '/'; otherwise '\'.
        /// </summary>
        static string CombineServerPath(string directory, string fileName)
        {
            if (string.IsNullOrEmpty(directory))
            {
                return fileName;
            }
            char last = directory[directory.Length - 1];
            if (last == '\\' || last == '/')
            {
                return directory + fileName;
            }
            char separator = directory.IndexOf('/') >= 0 && directory.IndexOf('\\') < 0 ? '/' : '\\';
            return directory + separator + fileName;
        }

        /// <summary>
        /// Returns <paramref name="suffix"/> the first time it is seen; later duplicates get a
        /// counter before the extension ("Data.mdf" -> "Data1.mdf") so no two files of a
        /// multi-file backup are moved onto the same path.
        /// </summary>
        static string MakeUniqueSuffix(string suffix, Dictionary<string, int> seen)
        {
            int count;
            seen.TryGetValue(suffix, out count);
            seen[suffix] = count + 1;
            if (count == 0)
            {
                return suffix;
            }
            int dot = suffix.LastIndexOf('.');
            return dot < 0
                ? suffix + count
                : suffix.Substring(0, dot) + count + suffix.Substring(dot);
        }

        #endregion

        /// <summary>
        /// Reads the logical files contained in a backup via RESTORE FILELISTONLY.
        /// </summary>
        /// <param name="backupFile">Backup file path as seen by the SQL Server instance.</param>
        List<OriginalFile> GetOriginalFileList(string backupFile)
        {
            // No ConnectSuccess short-circuit here: an empty list made RestoreDB restore
            // WITH REPLACE without any MOVE clause. A bad connection now surfaces as an exception.
            // RESTORE accepts a variable for the device, so the path is passed as a parameter.
            SqlParameter pathParam = new SqlParameter("@backupFile", SqlDbType.NVarChar, 4000)
            {
                Value = (object)backupFile ?? DBNull.Value
            };
            DataTable dt = ExecDataTable("RESTORE FILELISTONLY from disk = @backupFile", pathParam);

            List<OriginalFile> oFileList = new List<OriginalFile>(dt.Rows.Count);
            foreach (DataRow row in dt.Rows)
            {
                oFileList.Add(new OriginalFile
                {
                    LogicalName = row["LogicalName"].ToString(),
                    PhysicalName = row["PhysicalName"].ToString(),
                    Type = row["Type"].ToString()
                });
            }
            return oFileList;
        }

        /// <summary>
        /// Checks whether a database with the given name exists on the server.
        /// </summary>
        public bool DatabaseIsExist(string dbName)
        {
            SqlParameter nameParam = new SqlParameter("@name", SqlDbType.NVarChar, 128)
            {
                Value = (object)dbName ?? DBNull.Value
            };
            object o = GetScalar("select count(1) From master.dbo.sysdatabases where name=@name", nameParam);
            return Convert.ToInt32(o) > 0;
        }

        /// <summary>
        /// Creates an empty database (one data file, one log file) under
        /// <paramref name="dbSavePath"/> unless a database with that name already exists.
        /// </summary>
        void CreateDatabase(string dbName, string dbSavePath)
        {
            if (DatabaseIsExist(dbName))
            {
                return;
            }

            // CREATE DATABASE does not accept variables, so names are quoted/escaped instead.
            string quotedDb = QuoteName(dbName);
            string dataName = QuoteName(dbName + "_data");
            string dataFile = SqlLiteral(CombineServerPath(dbSavePath, dbName + "_data.mdf"));
            string logName = QuoteName(dbName + "_log");
            string logFile = SqlLiteral(CombineServerPath(dbSavePath, dbName + "_log.ldf"));

            string sqlTxt = $@"
Create database {quotedDb} ON PRIMARY
(
    NAME = {dataName},
    FILENAME = {dataFile}
)
LOG ON
(
    NAME = {logName},
    FILENAME = {logFile},
    size = 512KB
)
";
            ExecNonQuery(sqlTxt, null);
        }

        /// <summary>
        /// Restores <c>model.BackupFile</c> into <c>model.NewDBName</c> (WITH REPLACE), moving
        /// each file to <c>model.SaveFullPath</c> as "{NewDBName}_{suffix}" and renaming the
        /// logical files to the same "{NewDBName}_{suffix}" pattern.
        /// </summary>
        /// <returns>Always true; failures are reported as exceptions.</returns>
        /// <exception cref="InvalidOperationException">The backup lists no files.</exception>
        public bool RestoreDB(RestoreModel model)
        {
            // Read the file list first so nothing is created on the server if the backup is unusable.
            List<OriginalFile> list = GetOriginalFileList(model.BackupFile);
            if (list.Count == 0)
            {
                throw new InvalidOperationException("No files were found in backup: " + model.BackupFile);
            }

            CreateDatabase(model.NewDBName, model.SaveFullPath);

            string quotedDb = QuoteName(model.NewDBName);
            Dictionary<string, int> seen = new Dictionary<string, int>(StringComparer.OrdinalIgnoreCase);
            string[] suffixes = new string[list.Count];

            StringBuilder sb = new StringBuilder();
            sb.AppendFormat("restore database {0} From disk={1} with replace ", quotedDb, SqlLiteral(model.BackupFile));
            for (int i = 0; i < list.Count; i++)
            {
                suffixes[i] = MakeUniqueSuffix(list[i].TypeExpandName, seen);
                string target = CombineServerPath(model.SaveFullPath, model.NewDBName + "_" + suffixes[i]);
                sb.AppendFormat(",move {0} to {1}", SqlLiteral(list[i].LogicalName), SqlLiteral(target));
            }

            for (int i = 0; i < list.Count; i++)
            {
                // Rename the new database's logical files; skip files that already carry the
                // target name, since renaming a file to its current name is rejected.
                string newName = model.NewDBName + "_" + suffixes[i];
                if (string.Equals(list[i].LogicalName, newName, StringComparison.OrdinalIgnoreCase))
                {
                    continue;
                }
                sb.AppendLine();
                sb.AppendFormat("alter database {0} modify file (name={1},newname={2})",
                    quotedDb, SqlLiteral(list[i].LogicalName), SqlLiteral(newName));
            }

            ExecNonQueryCore(sb.ToString(), RestoreTimeoutSeconds, null);
            return true;
        }
    }
}